# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from math import sqrt
from memory import LegacyUnsafePointer as UnsafePointer, bitcast
from sys import size_of

import linalg.matmul.vendor.blas as vendor_blas
from gpu import WARP_SIZE, barrier
from gpu import lane_id as get_lane_id
from gpu.cluster import block_rank_in_cluster
from gpu.host import DeviceContext, FuncAttribute
from gpu.host.nvidia.tma import TensorMapSwizzle
from gpu import block_idx, lane_id, thread_idx, warp_id as get_warp_id
from gpu.memory import external_memory
from gpu.mma_sm100 import *
from gpu.tcgen05 import *
from layout import Layout, LayoutTensor
from layout._fillers import random
from layout._utils import ManagedLayoutTensor
from layout.int_tuple import IntTuple
from layout.tensor_core_async import (
    tile_layout_k_major,
    tile_layout_mn_major,
    tile_to_descriptor,
)
from layout.tma_async import SharedMemBarrier, TMATensorTile, create_tma_tile
from testing import assert_almost_equal

from utils.index import Index, IndexList
from utils.numerics import get_accum_type, max_finite, min_finite
from utils.static_tuple import StaticTuple


# Naive host-side reference matmul: C = op(A) @ op(B), accumulated in Float32.
#
# Parameters:
#   transpose_a: when True, A's layout is transposed before its extents are
#       read (so A is treated as K x M) and A is indexed as `k * M + m`.
#   transpose_b: when True, B's layout is transposed before its extents are
#       read (so B is treated as N x K) and B is indexed as `n * K + k`.
#
# Args:
#   C: output tensor; M and N are taken from its layout.
#   A: left operand.
#   B: right operand.
#
# All three tensors are accessed through their raw pointers with hand-computed
# dense row-major offsets, so strides stored in the layouts are not consulted.
# Shape agreement between C, A and B is checked at compile time.
fn cpu_matmul_naive[
    *, transpose_a: Bool, transpose_b: Bool
](C: LayoutTensor, A: LayoutTensor, B: LayoutTensor):
    comptime M = C.layout[0].size()
    comptime N = C.layout[1].size()
    # layout_a is M x K
    comptime layout_a = A.layout.transpose() if transpose_a else A.layout
    # layout_b is K x N
    comptime layout_b = B.layout.transpose() if transpose_b else B.layout
    comptime K = layout_a[1].size()
    __comptime_assert M == layout_a[0].size(), String(
        "C.M = ", M, "; A.M = ", layout_a[0].size()
    )
    # Report N here (the message previously printed M for the C.N value).
    __comptime_assert N == layout_b[1].size(), String(
        "C.N = ", N, "; B.N = ", layout_b[1].size()
    )
    __comptime_assert K == layout_b[0].size(), String(
        "A.K = ", K, "; B.K = ", layout_b[0].size()
    )
    for n in range(N):
        for m in range(M):
            # One output element: dot product over K in Float32.
            var acc: Float32 = 0.0
            for k in range(K):
                var a_idx: Int

                @parameter
                if transpose_a:
                    a_idx = k * M + m
                else:
                    a_idx = m * K + k
                var b_idx: Int

                @parameter
                if transpose_b:
                    b_idx = n * K + k
                else:
                    b_idx = k * N + n
                acc += (
                    A.ptr.load(a_idx).cast[DType.float32]()
                    * B.ptr.load(b_idx).cast[DType.float32]()
                )
            # C is written as dense row-major M x N, cast to C's dtype.
            c_idx = m * N + n
            C.ptr.store(c_idx, acc.cast[C.dtype]())


# "SS" matmul kernel: both the A and B operand tiles are copied into shared
# memory by TMA (`async_copy`) and fed to `mma` through shared-memory
# descriptors. The accumulator lives at `tmem_addr`; after the K loop it is
# read back with `tcgen05_ld` and written to the (BM x BN) tile of `c`
# selected by (block_idx.y, block_idx.x).
#
# Parameters:
#   a_type, b_type: operand dtypes; constrained to be equal and one of
#       float8_e4m3fn / bfloat16.
#   c_type: dtype of the output tensor `c`.
#   a_layout, b_layout, a_desc_layout, b_desc_layout: layouts of the TMA tiles.
#   c_layout: layout of `c`.
#   block_tile_shape: (BM, BN, BK) tile handled per block per K iteration.
#   mma_shape: (MMA_M, MMA_N, MMA_K) shape of a single `mma` issue.
#   transpose_a, transpose_b: select K-major vs MN-major shared-memory tiles
#       (A is K-major when `not transpose_a`, B is K-major when `transpose_b`).
#   cluster_shape: forwarded to the `nvvm.cluster_dim` metadata.
#   a_swizzle, b_swizzle: swizzle modes for the shared-memory tiles.
#   num_threads: threads per block; constrained to 128 or 256.
#
# Args:
#   a_tma_op, b_tma_op: TMA tile objects used to load A and B.
#   c: output tensor.
#   num_iters: number of BK-sized steps taken along K.
#
# Dynamic shared memory layout, in order:
#   [A tile][B tile][2 x UInt32 (tensor-memory address slot)]
#   [tma_mbar][mma_mbar]
@__llvm_metadata(`nvvm.cluster_dim`=cluster_shape)
@__llvm_arg_metadata(a_tma_op, `nvvm.grid_constant`)
@__llvm_arg_metadata(b_tma_op, `nvvm.grid_constant`)
fn tma_umma_kernel_ss[
    a_type: DType,
    b_type: DType,
    c_type: DType,
    a_layout: Layout,
    b_layout: Layout,
    c_layout: Layout,
    a_desc_layout: Layout,
    b_desc_layout: Layout,
    block_tile_shape: IndexList[3],
    mma_shape: IndexList[3],
    transpose_a: Bool = False,
    transpose_b: Bool = True,
    cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1),
    a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE,
    b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE,
    num_threads: UInt = 128,
](
    a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout],
    b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout],
    c: LayoutTensor[c_type, c_layout, MutAnyOrigin],
    num_iters: UInt,
):
    constrained[num_threads == 128 or num_threads == 256]()
    constrained[
        a_type == b_type and a_type in (DType.float8_e4m3fn, DType.bfloat16),
        (
            "a_type and b_type must be the same and either float8_e4m3fn or"
            " bfloat16"
        ),
    ]()

    comptime BM = block_tile_shape[0]
    comptime BN = block_tile_shape[1]
    comptime BK = block_tile_shape[2]
    comptime MMA_M = mma_shape[0]
    comptime MMA_N = mma_shape[1]
    comptime MMA_K = mma_shape[2]
    comptime num_m_mmas = BM // MMA_M
    comptime num_n_mmas = BN // MMA_N
    comptime num_k_mmas = BK // MMA_K

    # Pick the shared-memory tile layout matching each operand's major-ness.
    comptime a_k_major = not transpose_a
    comptime b_k_major = transpose_b
    comptime a_smem_layout = tile_layout_k_major[
        a_type, BM, BK, swizzle_mode=a_swizzle
    ]() if a_k_major else tile_layout_mn_major[
        a_type, BM, BK, swizzle_mode=a_swizzle
    ]()
    comptime b_smem_layout = tile_layout_k_major[
        b_type, BN, BK, swizzle_mode=b_swizzle
    ]() if b_k_major else tile_layout_mn_major[
        b_type, BN, BK, swizzle_mode=b_swizzle
    ]()

    # Base of the dynamic shared memory; the A tile starts here.
    a_smem = rebind[
        UnsafePointer[Scalar[a_type], address_space = AddressSpace.SHARED]
    ](
        external_memory[
            Scalar[a_type],
            address_space = AddressSpace.SHARED,
            alignment=128,
            name="tmem_test_dynamic_shared_memory",
        ]()
    )
    comptime a_smem_tile_t = LayoutTensor[
        a_type,
        a_smem_layout,
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
        alignment=128,
    ]
    comptime b_smem_tile_t = LayoutTensor[
        b_type,
        b_smem_layout,
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
        alignment=128,
    ]

    comptime a_size = a_smem_layout.size()
    comptime b_size = b_smem_layout.size()

    # The B tile and the bookkeeping words are placed directly after the
    # preceding region, so each region's byte size must keep the next one
    # aligned.
    constrained[
        ((a_size * size_of[a_type]()) % 128) == 0, "preserve alignment"
    ]()
    constrained[
        ((b_size * size_of[b_type]()) % 16) == 0, "preserve alignment"
    ]()
    var b_smem = (a_smem + a_size).bitcast[Scalar[b_type]]()

    var a_smem_tile = a_smem_tile_t(a_smem)
    var b_smem_tile = b_smem_tile_t(b_smem)

    # Shared memory pointer to hold tensor memory address
    var ptr_tmem_addr = (b_smem + b_size).bitcast[UInt32]()

    comptime accum_type = get_accum_type[a_type]()

    # Per-thread share of one MMA_M x MMA_N accumulator tile.
    comptime c_frag_size = MMA_M * MMA_N // Int(num_threads)
    var c_frag = SIMD[accum_type, c_frag_size]()

    # Total bytes the TMA barrier should expect per iteration (A + B tiles).
    comptime a_expected_bytes = a_size * size_of[a_type]()
    comptime b_expected_bytes = b_size * size_of[b_type]()
    comptime expected_bytes = a_expected_bytes + b_expected_bytes

    # The two barriers sit two UInt32 slots past the tensor-memory address.
    tma_mbar = (ptr_tmem_addr + 2).bitcast[SharedMemBarrier]()
    mma_mbar = tma_mbar + 1

    if thread_idx.x == 0:
        tma_mbar[0].init()
        mma_mbar[0].init()

    # Phase bits passed to the barrier waits; flipped after every wait.
    var tma_phase: UInt32 = 0
    var mma_phase: UInt32 = 0

    var elect_one_warp = get_warp_id() == 0
    var elect_one_thread = thread_idx.x == 0
    # NOTE(review): elect_one_cta is computed but never read in this kernel.
    var elect_one_cta = block_rank_in_cluster() % 2 == 0
    comptime max_tmem_cols = 512

    # Warp 0 allocates tensor memory; the address is written to shared memory.
    if elect_one_warp:
        tcgen05_alloc[1](ptr_tmem_addr, max_tmem_cols)

    # Ensure all threads see the initialized mbarriers and
    # tensor memory allocation
    barrier()

    tmem_addr = ptr_tmem_addr[0]

    # With 256 threads, threads 128..255 use an offset tensor-memory address.
    @parameter
    if num_threads > 128:
        if thread_idx.x >= 128:
            tmem_addr += 1 << 20  # offset for lane 16

    # Derive the descriptor byte offsets (SBO / LBO) from the canonical
    # layouts. For MN-major operands with a swizzle mode the two strides swap
    # roles.
    comptime a_canonical_layout = tile_to_descriptor[
        a_type, a_smem_layout, is_k_major=a_k_major
    ]()
    comptime b_canonical_layout = tile_to_descriptor[
        b_type, b_smem_layout, is_k_major=b_k_major
    ]()
    comptime a_stride01 = a_canonical_layout[0].stride[1].value()
    comptime a_stride11 = a_canonical_layout[1].stride[1].value()
    comptime aSBO = (
        a_stride01 if a_k_major
        or a_swizzle == TensorMapSwizzle.SWIZZLE_NONE else a_stride11
    ) * size_of[a_type]()
    comptime aLBO = (
        a_stride11 if a_k_major
        or a_swizzle == TensorMapSwizzle.SWIZZLE_NONE else a_stride01
    ) * size_of[a_type]()
    comptime b_stride01 = b_canonical_layout[0].stride[1].value()
    comptime b_stride11 = b_canonical_layout[1].stride[1].value()
    comptime bSBO = (
        b_stride01 if b_k_major
        or b_swizzle == TensorMapSwizzle.SWIZZLE_NONE else b_stride11
    ) * size_of[b_type]()
    comptime bLBO = (
        b_stride11 if b_k_major
        or b_swizzle == TensorMapSwizzle.SWIZZLE_NONE else b_stride01
    ) * size_of[b_type]()

    adesc = MMASmemDescriptor.create[aSBO, aLBO, a_swizzle](a_smem_tile.ptr)
    bdesc = MMASmemDescriptor.create[bSBO, bLBO, b_swizzle](b_smem_tile.ptr)

    # FP8 inputs use the F8F6F4 MMA kind; otherwise (bfloat16) use F16.
    comptime mma_kind = UMMAKind.KIND_F8F6F4 if a_type == DType.float8_e4m3fn else UMMAKind.KIND_F16
    idesc = UMMAInsDescriptor[mma_kind].create[
        accum_type,
        a_type,
        b_type,
        Index[dtype = DType.uint32](mma_shape[0], mma_shape[1]),
        transpose_a=transpose_a,
        transpose_b=transpose_b,
    ]()

    # Main K loop: one BK-wide slab of A and B per iteration.
    for i in range(num_iters):
        # A single thread arms the barrier and issues both TMA copies.
        if elect_one_thread:
            tma_mbar[0].expect_bytes(expected_bytes)

            var m = block_idx.y * UInt(BM)
            var n = block_idx.x * UInt(BN)
            var k = i * UInt(BK)
            a_tma_op.async_copy(
                a_smem_tile,
                tma_mbar[0],
                (m, k) if transpose_a else (k, m),
            )
            b_tma_op.async_copy(
                b_smem_tile,
                tma_mbar[0],
                (k, n) if transpose_b else (n, k),
            )

        # Every thread waits for the tiles to land, then flips its phase.
        tma_mbar[0].wait(tma_phase)
        tma_phase ^= 1

        if elect_one_thread:
            # The very first MMA uses c_scale=0 (presumably so the
            # accumulator is overwritten instead of accumulated into —
            # confirm); every later MMA uses c_scale=1.
            if i == 0:
                mma[c_scale=0](adesc, bdesc, tmem_addr, idesc)

                @parameter
                for j in range(1, num_k_mmas):
                    # Byte offset of the j-th MMA_K slice inside each tile.
                    comptime idx = IntTuple(0, MMA_K * j)
                    comptime a_offset = a_smem_layout(idx) * size_of[a_type]()
                    comptime b_offset = b_smem_layout(idx) * size_of[b_type]()
                    mma[c_scale=1](
                        adesc + a_offset, bdesc + b_offset, tmem_addr, idesc
                    )
            else:

                @parameter
                for j in range(num_k_mmas):
                    comptime idx = IntTuple(0, MMA_K * j)
                    comptime a_offset = a_smem_layout(idx) * size_of[a_type]()
                    comptime b_offset = b_smem_layout(idx) * size_of[b_type]()
                    mma[c_scale=1](
                        adesc + a_offset, bdesc + b_offset, tmem_addr, idesc
                    )

            mma_arrive(mma_mbar)

        # Wait on the MMA barrier before the shared tiles are overwritten by
        # the next iteration's TMA copies.
        mma_mbar[0].wait(mma_phase)
        mma_phase ^= 1

    # Read this thread's accumulator fragment back from tensor memory.
    # NOTE(review): `repeat` is derived from BN while `width` is derived from
    # MMA_N; every caller visible in this file uses BN == MMA_N.
    c_frag = tcgen05_ld[
        datapaths=16,
        bits=256,
        repeat = BN // 8,
        dtype=accum_type,
        pack=False,
        width=c_frag_size,
    ](tmem_addr)

    tcgen05_load_wait()

    # Warp 0 (the allocating warp) releases the tensor memory.
    if elect_one_warp:
        tcgen05_release_allocation_lock[1]()
        tcgen05_dealloc[1](tmem_addr, max_tmem_cols)

    comptime num_warps = num_threads // UInt(WARP_SIZE)
    var warp_id = get_warp_id()

    # With 256 threads, remap warp ids so that warps w and w + 4 map to
    # adjacent row groups (2 * (w % 4) and 2 * (w % 4) + 1).
    @parameter
    if num_threads > 128:
        warp_id = UInt(2 * Int(warp_id % 4) + Int(warp_id // 4))

    ctile = c.tile[BM, BN](Int(block_idx.y), Int(block_idx.x))

    # Scatter the register fragment to global memory.
    @parameter
    for m_mma in range(num_m_mmas):

        @parameter
        for n_mma in range(num_n_mmas):
            # NOTE(review): mma_id is computed but not used below.
            comptime mma_id = n_mma * num_m_mmas + m_mma

            c_gmem_warp_tile = ctile.tile[MMA_M // Int(num_warps), MMA_N](
                4 * m_mma + Int(warp_id), n_mma
            )

            # Each lane owns 1x2 vectors, lanes arranged as an 8 x 4 grid.
            c_gmem_frag = c_gmem_warp_tile.vectorize[1, 2]().distribute[
                Layout.row_major(8, 4)
            ](lane_id())

            comptime num_vecs_m = c_gmem_frag.layout.shape[0].value()
            comptime num_vecs_n = c_gmem_frag.layout.shape[1].value()

            @parameter
            for n_vec in range(num_vecs_n):

                @parameter
                for m_vec in range(num_vecs_m):
                    # Fragment registers are consumed with the m index
                    # varying fastest.
                    comptime i_vec = n_vec * num_vecs_m + m_vec

                    c_gmem_frag[m_vec, n_vec] = rebind[
                        c_gmem_frag.element_type
                    ](
                        SIMD[accum_type, 2](
                            c_frag[2 * i_vec], c_frag[2 * i_vec + 1]
                        ).cast[c_type]()
                    )


# "TS" matmul kernel: the A operand is read from global memory into
# registers and stored into tensor memory with `tcgen05_st`, while the B
# operand is copied into shared memory by TMA. `mma` is issued with the A
# tensor-memory address and a shared-memory descriptor for B. The result is
# read back with `tcgen05_ld` and written to the (BM x BN) tile of `c`
# selected by (block_idx.y, block_idx.x).
#
# Parameters:
#   a_type, b_type: operand dtypes; constrained to both be bfloat16.
#   c_type: dtype of the output tensor `c`.
#   a_layout, c_layout: layouts of the global tensors `a` and `c`.
#   b_layout, b_desc_layout: layouts of the B TMA tile.
#   block_tile_shape: (BM, BN, BK); constrained so that BM // MMA_M and
#       BN // MMA_N are both 1.
#   mma_shape: (MMA_M, MMA_N, MMA_K) shape of a single `mma` issue.
#   transpose_b: B's shared-memory tile is K-major when True, MN-major
#       otherwise.
#   b_swizzle: swizzle mode for the B shared-memory tile.
#   num_threads: threads per block; constrained to 128 or 256.
#
# Args:
#   a: A operand in global memory.
#   b_tma_op: TMA tile object used to load B.
#   c: output tensor.
#   num_iters: number of BK-sized steps taken along K.
#
# Dynamic shared memory layout, in order:
#   [B tile][2 x UInt32 (tensor-memory address slot)][tma_mbar][mma_mbar]
@__llvm_arg_metadata(b_tma_op, `nvvm.grid_constant`)
fn tma_umma_kernel_ts[
    a_type: DType,
    b_type: DType,
    c_type: DType,
    a_layout: Layout,
    b_layout: Layout,
    c_layout: Layout,
    b_desc_layout: Layout,
    block_tile_shape: IndexList[3],
    mma_shape: IndexList[3],
    transpose_b: Bool = True,
    b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE,
    num_threads: UInt = 128,
](
    a: LayoutTensor[a_type, a_layout, MutAnyOrigin],
    b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout],
    c: LayoutTensor[c_type, c_layout, MutAnyOrigin],
    num_iters: UInt,
):
    constrained[num_threads == 128 or num_threads == 256]()
    comptime BM = block_tile_shape[0]
    comptime BN = block_tile_shape[1]
    comptime BK = block_tile_shape[2]
    comptime MMA_M = mma_shape[0]
    comptime MMA_N = mma_shape[1]
    comptime MMA_K = mma_shape[2]
    comptime num_m_mmas = BM // MMA_M
    comptime num_n_mmas = BN // MMA_N

    constrained[
        num_m_mmas == 1 and num_n_mmas == 1,
        "num_m_mmas and num_n_mmas must be 1",
    ]()
    constrained[
        a_type == b_type and a_type == DType.bfloat16,
        "a_type and b_type must be the same and bfloat16 type",
    ]()
    # B's shared-memory tile layout follows its major-ness.
    comptime b_smem_layout = tile_layout_k_major[
        b_type, BN, BK, swizzle_mode=b_swizzle
    ]() if transpose_b else tile_layout_mn_major[
        b_type, BN, BK, swizzle_mode=b_swizzle
    ]()

    # Base of the dynamic shared memory; the B tile starts here.
    b_smem = rebind[
        UnsafePointer[Scalar[b_type], address_space = AddressSpace.SHARED]
    ](
        external_memory[
            Scalar[b_type],
            address_space = AddressSpace.SHARED,
            alignment=128,
            name="tmem_test_dynamic_shared_memory",
        ]()
    )
    comptime b_smem_tile_t = LayoutTensor[
        b_type,
        b_smem_layout,
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
        alignment=128,
    ]

    var b_smem_tile = b_smem_tile_t(b_smem)
    comptime b_size = b_smem_tile_t.layout.size()

    comptime accum_type = get_accum_type[a_type]()

    # The bookkeeping words follow the B tile directly, so the tile's byte
    # size must keep them aligned.
    constrained[
        ((b_size * size_of[b_type]()) % 16) == 0, "preserve alignment"
    ]()
    # Shared memory pointer to hold tensor memory address
    var ptr_tmem_addr = (b_smem + b_size).bitcast[UInt32]()

    # Per-thread share of one MMA_M x MMA_N accumulator tile.
    comptime c_frag_size = MMA_M * MMA_N // Int(num_threads)
    var c_frag = SIMD[accum_type, c_frag_size]()

    # Only B goes through TMA, so only its bytes are expected on the barrier.
    comptime b_expected_bytes = b_size * size_of[b_type]()
    comptime expected_bytes = b_expected_bytes

    # The two barriers sit two UInt32 slots past the tensor-memory address.
    tma_mbar = (ptr_tmem_addr + 2).bitcast[SharedMemBarrier]()
    mma_mbar = tma_mbar + 1

    if thread_idx.x == 0:
        tma_mbar[0].init()
        mma_mbar[0].init()

    # Phase bits passed to the barrier waits; flipped after every wait.
    var tma_phase: UInt32 = 0
    var mma_phase: UInt32 = 0

    var elect_one_warp = get_warp_id() == 0
    var elect_one_thread = thread_idx.x == 0
    comptime max_tmem_cols = 512

    # Warp 0 allocates tensor memory; the address is written to shared memory.
    if elect_one_warp:
        tcgen05_alloc[1](ptr_tmem_addr, max_tmem_cols)

    # Ensure all threads see the initialized mbarriers and
    # tensor memory allocation
    barrier()

    var tmem_addr = ptr_tmem_addr[0]

    # With 256 threads, threads 128..255 use an offset tensor-memory address.
    @parameter
    if num_threads > 128:
        if thread_idx.x >= 128:
            tmem_addr += 1 << 20  # offset for lane 16
    # The accumulator starts at the allocation base; the A operand is placed
    # MMA_N past it (presumably MMA_N tensor-memory columns — confirm).
    var c_tmem: UInt32 = tmem_addr
    var a_tmem: UInt32 = tmem_addr + MMA_N

    # Derive B's descriptor byte offsets (SBO / LBO); the two strides swap
    # roles when B is not K-major.
    comptime b_canonical_layout = tile_to_descriptor[
        b_type, b_smem_layout, is_k_major=transpose_b
    ]()
    comptime b_stride01 = b_canonical_layout[0].stride[1].value()
    comptime b_stride11 = b_canonical_layout[1].stride[1].value()
    comptime bSBO = (b_stride01 if transpose_b else b_stride11) * size_of[
        b_type
    ]()
    comptime bLBO = (b_stride11 if transpose_b else b_stride01) * size_of[
        b_type
    ]()

    bdesc = MMASmemDescriptor.create[bSBO, bLBO, b_swizzle](b_smem_tile.ptr)

    idesc = UMMAInsDescriptor[UMMAKind.KIND_F16].create[
        accum_type,
        a_type,
        b_type,
        Index[dtype = DType.uint32](mma_shape[0], mma_shape[1]),
        transpose_b=transpose_b,
    ]()

    comptime num_warps = num_threads // UInt(WARP_SIZE)
    var warp_id = get_warp_id()

    # With 256 threads, remap warp ids so that warps w and w + 4 map to
    # adjacent row groups (2 * (w % 4) and 2 * (w % 4) + 1).
    @parameter
    if num_threads > 128:
        warp_id = UInt(2 * Int(warp_id % 4) + Int(warp_id // 4))

    # Number of 32-bit words of the BM x BK A tile held by each thread.
    comptime a_frag_size = BM * BK * size_of[a_type]() // 4 // Int(num_threads)
    var a_frag = SIMD[DType.uint32, a_frag_size]()

    # Main K loop: one BK-wide slab of A and B per iteration.
    for i in range(num_iters):
        # Load A from global memory to registers.
        # Each thread loads 32 values
        a_gmem_tile = a.tile[BM, BK](Int(block_idx.y), Int(i))
        a_gmem_warp_tile = a_gmem_tile.tile[BM // Int(num_warps), BK](
            Int(warp_id), 0
        )
        # Vectorize by 4 for 16x256 load, each thread loads multiple vector
        # of size 2x4B=4xBF16
        a_gmem_frag = a_gmem_warp_tile.vectorize[1, 4]().distribute[
            Layout.row_major(8, 4)
        ](get_lane_id())
        comptime num_vecs_m = a_gmem_frag.layout.shape[0].value()
        comptime num_vecs_k = a_gmem_frag.layout.shape[1].value()

        @parameter
        for k in range(num_vecs_k):

            @parameter
            for j in range(num_vecs_m):
                vec = a_gmem_frag[j, k]
                # Each 4-element vector is split into two halves, each
                # bit-cast to one 32-bit register word.
                comptime idx = k * num_vecs_m + j
                a_frag[2 * idx] = bitcast[DType.uint32, 1](vec.split()[0])
                a_frag[2 * idx + 1] = bitcast[DType.uint32, 1](vec.split()[1])

        # Store the register fragment of A into tensor memory.
        tcgen05_st[
            datapaths=16,
            bits=256,
            repeat = BK * size_of[a_type]() // 4 // 8,
            pack=False,
        ](a_tmem, a_frag)

        # store_wait synchronizes within a warp. One warp could go ahead
        # while other warps are still storing to tmem.
        tcgen05_store_wait()
        barrier()

        # Load B by TMA
        if elect_one_thread:
            tma_mbar[0].expect_bytes(expected_bytes)

            b_tma_op.async_copy(
                b_smem_tile,
                tma_mbar[0],
                (i * UInt(BK), block_idx.x * UInt(BN)) if transpose_b else (
                    block_idx.x * UInt(BN),
                    i * UInt(BK),
                ),
            )

        # Sync TMA and tcgen05_st because the latter can sync across warps.
        tma_mbar[0].wait(tma_phase)
        tma_phase ^= 1

        if elect_one_thread:
            # Step between consecutive MMA_K slices of A in tensor memory.
            # NOTE(review): the `// 2` presumably reflects two bfloat16
            # values per 32-bit word — confirm.
            comptime atmem_kstride = mma_shape[2] // 2  # * size_of[a_type]()
            # The very first MMA uses c_scale=0 (presumably so the
            # accumulator is overwritten instead of accumulated into —
            # confirm); every later MMA uses c_scale=1.
            if i == 0:
                mma[c_scale=0](a_tmem, bdesc, c_tmem, idesc)

                @parameter
                for j in range(1, BK // mma_shape[2]):
                    # Byte offset of the j-th MMA_K slice inside the B tile.
                    comptime b_idx = IntTuple(MMA_N * 0, MMA_K * j)
                    comptime b_offset = b_smem_layout(b_idx) * size_of[b_type]()
                    mma[c_scale=1](
                        a_tmem + j * atmem_kstride,
                        bdesc + b_offset,
                        c_tmem,
                        idesc,
                    )
            else:

                @parameter
                for j in range(BK // mma_shape[2]):
                    comptime b_idx = IntTuple(MMA_N * 0, MMA_K * j)
                    comptime b_offset = b_smem_layout(b_idx) * size_of[b_type]()
                    mma[c_scale=1](
                        a_tmem + j * atmem_kstride,
                        bdesc + b_offset,
                        c_tmem,
                        idesc,
                    )

            mma_arrive(mma_mbar)

        # Wait on the MMA barrier before A (tensor memory) and B (shared
        # memory) are overwritten by the next iteration.
        mma_mbar[0].wait(mma_phase)
        mma_phase ^= 1

    # Each thread owns a row in c tile. This is inefficient but to
    # test the instruction shape.
    c_frag = tcgen05_ld[
        datapaths=16,
        bits=256,
        repeat = BN // 8,
        dtype=accum_type,
        pack=False,
        width=c_frag_size,
    ](c_tmem)

    tcgen05_load_wait()

    # Warp 0 (the allocating warp) releases the tensor memory.
    if elect_one_warp:
        tcgen05_release_allocation_lock[1]()
        tcgen05_dealloc[1](tmem_addr, max_tmem_cols)

    ctile = c.tile[BM, BN](Int(block_idx.y), Int(block_idx.x))

    # Scatter the register fragment to global memory. Both loops run exactly
    # once because num_m_mmas and num_n_mmas are constrained to 1 above.
    @parameter
    for m_mma in range(num_m_mmas):

        @parameter
        for n_mma in range(num_n_mmas):
            # NOTE(review): mma_id is computed but not used below.
            comptime mma_id = n_mma * num_m_mmas + m_mma

            c_gmem_warp_tile = ctile.tile[MMA_M // Int(num_warps), MMA_N](
                4 * m_mma + Int(warp_id), n_mma
            )

            # Each lane owns 1x2 vectors, lanes arranged as an 8 x 4 grid.
            c_gmem_frag = c_gmem_warp_tile.vectorize[1, 2]().distribute[
                Layout.row_major(8, 4)
            ](lane_id())

            comptime num_vecs_m = c_gmem_frag.layout.shape[0].value()
            comptime num_vecs_n = c_gmem_frag.layout.shape[1].value()

            @parameter
            for n_vec in range(num_vecs_n):

                @parameter
                for m_vec in range(num_vecs_m):
                    # Fragment registers are consumed with the m index
                    # varying fastest.
                    comptime i_vec = n_vec * num_vecs_m + m_vec

                    c_gmem_frag[m_vec, n_vec] = rebind[
                        c_gmem_frag.element_type
                    ](
                        SIMD[accum_type, 2](
                            c_frag[2 * i_vec], c_frag[2 * i_vec + 1]
                        ).cast[c_type]()
                    )


# Runs one UMMA matmul configuration on the device and checks it against a
# reference result.
#
# Steps: print the configuration, allocate and randomly fill A and B, build
# the TMA tiles, launch either `tma_umma_kernel_ss` (a_smem=True) or
# `tma_umma_kernel_ts` (a_smem=False), compute a reference with
# `vendor_blas.matmul` or `cpu_matmul_naive`, then compare every element with
# `assert_almost_equal` (atol = rtol = 0.01).
#
# Parameters:
#   a_type, b_type, c_type: element types of A, B and C.
#   prob_shape: (M, N, K) problem size.
#   block_tile_shape: (BM, BN, BK) tile per block; the grid is
#       (N // BN, M // BM) and the kernel runs K // BK iterations.
#   mma_shape: (MMA_M, MMA_N, MMA_K); the block size is 2 * MMA_M threads.
#   transpose_a: A is allocated as row-major K x M when True, M x K otherwise.
#   transpose_b: B is allocated as row-major N x K when True, K x N otherwise.
#   cluster_shape: forwarded to the SS kernel only.
#   a_swizzle, b_swizzle: swizzle modes for the TMA tiles.
#   a_smem: selects the SS kernel (True) or the TS kernel (False).
#   cta_group: NOTE(review): not referenced anywhere in this function.
#
# Args:
#   ctx: device context used for allocation, launches and synchronization.
def test_tma_umma[
    a_type: DType,
    b_type: DType,
    c_type: DType,
    prob_shape: IndexList[3],
    block_tile_shape: IndexList[3],
    mma_shape: IndexList[3],
    transpose_a: Bool = False,
    transpose_b: Bool = True,
    cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1),
    a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE,
    b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE,
    a_smem: Bool = True,
    cta_group: Int = 1,
](ctx: DeviceContext):
    comptime BM = block_tile_shape[0]
    comptime BN = block_tile_shape[1]
    comptime BK = block_tile_shape[2]

    # Only MMA_M is used below (for the block size); MMA_N and MMA_K are
    # declared but not read in this function.
    comptime MMA_M = mma_shape[0]
    comptime MMA_N = mma_shape[1]
    comptime MMA_K = mma_shape[2]

    # One-line description of the configuration under test.
    print(
        "mma_"
        + ("s" if a_smem else "t")
        + "s_"
        + String(a_type)
        + "_"
        + String(b_type)
        + "_"
        + String(c_type)
        + " problem shape "
        + String(prob_shape)
        + " block tile "
        + String(block_tile_shape)
        + " transa="
        + String(transpose_a)
        + " transb="
        + String(transpose_b)
        + "; inst shape "
        + String(mma_shape)
        + " A "
        + (String(a_swizzle) if a_smem else "tmem")
        + " B "
        + String(b_swizzle)
    )

    comptime M = prob_shape[0]
    comptime N = prob_shape[1]
    comptime K = prob_shape[2]

    var a = ManagedLayoutTensor[
        a_type,
        Layout.row_major(K, M) if transpose_a else Layout.row_major(M, K),
    ](ctx)

    # Inputs are drawn from [-x, x] where x is the fourth root of the largest
    # finite value of the element type (presumably to keep products and sums
    # well inside the representable range — confirm).
    var a_extreme: Float32 = sqrt(
        sqrt(max_finite[a_type]().cast[DType.float32]())
    )
    random(
        a.tensor[update=False](),
        min=(-a_extreme).cast[a_type](),
        max=a_extreme.cast[a_type](),
    )

    comptime b_layout = Layout.row_major(
        N, K
    ) if transpose_b else Layout.row_major(K, N)
    var b = ManagedLayoutTensor[b_type, b_layout](ctx)
    # N x K copy of B; only filled and used in the FP8 / non-transposed-B
    # reference path below, but always allocated.
    var b_col_major = ManagedLayoutTensor[b_type, Layout.row_major(N, K)](ctx)

    var b_extreme: Float32 = sqrt(
        sqrt(max_finite[b_type]().cast[DType.float32]())
    )
    random(
        b.tensor[update=False](),
        min=(-b_extreme).cast[b_type](),
        max=b_extreme.cast[b_type](),
    )

    # Kernel output and reference output, both row-major M x N.
    var c = ManagedLayoutTensor[
        c_type,
        Layout.row_major(M, N),
    ](ctx)

    var c_ref = ManagedLayoutTensor[
        c_type,
        Layout.row_major(M, N),
    ](ctx)

    # Both TMA tiles are created unconditionally; a_tma_op is only passed to
    # the SS kernel (the TS branch still uses its type? no — it uses only
    # b_tma_op's type and a.device_tensor()).
    a_tma_op = create_tma_tile[
        Index(BK, BM) if transpose_a else Index(BM, BK),
        is_k_major = not transpose_a,
        swizzle_mode=a_swizzle,
    ](ctx, a.device_tensor())
    b_tma_op = create_tma_tile[
        Index(BN, BK) if transpose_b else Index(BK, BN),
        is_k_major=transpose_b,
        swizzle_mode=b_swizzle,
    ](ctx, b.device_tensor())

    # 128 threads for MMA_M == 64, 256 threads for MMA_M == 128.
    comptime block_dim = UInt(2 * MMA_M)

    @parameter
    if a_smem:
        # A tile + B tile, plus 24 trailing bytes for the bookkeeping the
        # kernel places after the tiles (2 x UInt32 and two barriers;
        # presumably 8 bytes per barrier — confirm).
        comptime smem_use = (BM + BN) * size_of[a_type]() * BK + 24
        comptime kernel = tma_umma_kernel_ss[
            a_type,
            b_type,
            c_type,
            type_of(a_tma_op).layout,
            type_of(b_tma_op).layout,
            Layout.row_major(M, N),
            type_of(a_tma_op).desc_layout,
            type_of(b_tma_op).desc_layout,
            block_tile_shape,
            mma_shape,
            transpose_a=transpose_a,
            transpose_b=transpose_b,
            cluster_shape=cluster_shape,
            a_swizzle=a_swizzle,
            b_swizzle=b_swizzle,
            num_threads=block_dim,
        ]
        ctx.enqueue_function_checked[kernel, kernel](
            a_tma_op,
            b_tma_op,
            c.device_tensor(),
            UInt(K // BK),
            grid_dim=(N // BN, M // BM),
            block_dim=(block_dim),
            shared_mem_bytes=Int(smem_use),
            func_attribute=FuncAttribute.MAX_DYNAMIC_SHARED_SIZE_BYTES(
                smem_use
            ),
        )

    else:
        # Only the B tile lives in shared memory for the TS kernel, plus the
        # same 24 trailing bookkeeping bytes. A is always passed with a
        # row-major M x K layout; transpose_a is not forwarded.
        comptime smem_use = BN * size_of[b_type]() * BK + 24
        comptime kernel = tma_umma_kernel_ts[
            a_type,
            b_type,
            c_type,
            Layout.row_major(M, K),
            type_of(b_tma_op).layout,
            Layout.row_major(M, N),
            type_of(b_tma_op).desc_layout,
            block_tile_shape,
            mma_shape,
            transpose_b=transpose_b,
            b_swizzle=b_swizzle,
            num_threads=block_dim,
        ]

        ctx.enqueue_function_checked[kernel, kernel](
            a.device_tensor(),
            b_tma_op,
            c.device_tensor(),
            UInt(K // BK),
            grid_dim=(N // BN, M // BM),
            block_dim=(block_dim),
            shared_mem_bytes=Int(smem_use),
            func_attribute=FuncAttribute.MAX_DYNAMIC_SHARED_SIZE_BYTES(
                smem_use
            ),
        )

    # Reference result. Three paths:
    #   1. FP8 with non-transposed B: transpose B on the host into
    #      b_col_major and call the vendor BLAS with transpose_b=True
    #      (transpose_a is not forwarded on this path).
    #   2. All of M, N, K >= 64: call the vendor BLAS directly.
    #   3. Otherwise: naive CPU matmul on the host tensors.
    @parameter
    if a_type == DType.float8_e4m3fn and (not transpose_b):
        # NOTE: Matrix B should always be in col-major layout for cublasLt to work
        var b_host_col_major = b_col_major.tensor()
        var b_tensor = b.tensor()
        for i in range(N):
            for j in range(K):
                b_host_col_major[i, j] = b_tensor[j, i]

        vendor_blas.matmul(
            ctx,
            c_ref.device_tensor[update=False](),
            a.device_tensor[update=False](),
            b_col_major.device_tensor[update=True](),
            c_row_major=True,
            transpose_b=True,
        )

    elif M >= 64 and N >= 64 and K >= 64:
        vendor_blas.matmul(
            ctx,
            c_ref.device_tensor[update=False](),
            a.device_tensor[update=False](),
            b.device_tensor[update=False](),
            c_row_major=True,
            transpose_a=transpose_a,
            transpose_b=transpose_b,
        )
    else:
        cpu_matmul_naive[transpose_a=transpose_a, transpose_b=transpose_b](
            c_ref.tensor[update=False](),
            a.tensor[update=False](),
            b.tensor[update=False](),
        )
        # NOTE(review): the result above was written through the host view;
        # this call presumably pushes it to the device copy so the later
        # `c_ref.tensor()` reads the same data back, which would make the
        # inline "update host" remark inaccurate — confirm.
        _ = c_ref.device_tensor()  # update host

    ctx.synchronize()

    c_host = c.tensor()
    c_host_ref = c_ref.tensor()

    # Element-wise comparison; the failure message carries the (m, n) index.
    for m in range(M):
        for n in range(N):
            # Increased tolerance for FP8/bfloat16 accumulation errors
            # FP8/bf16 matrix multiplication can have larger numerical errors
            # due to reduced precision in intermediate accumulations
            assert_almost_equal(
                c_host[m, n],
                c_host_ref[m, n],
                atol=0.01,
                rtol=0.01,
                msg=String(m) + ", " + String(n),
            )
            # print(m, n, c_host[m, n], c_host_ref[m, n])

    # Explicitly consume the managed tensors here, after all uses above.
    _ = a^
    _ = b^
    _ = b_col_major^
    _ = c^
    _ = c_ref^


def main():
    # Test driver: instantiates `test_tma_umma` for every configuration below
    # and runs them, in order, on a single device context.
    with DeviceContext() as ctx:
        # Sweep 1: 128B-swizzled tiles for bfloat16 and FP8 inputs.
        @parameter
        for elem_type in [DType.bfloat16, DType.float8_e4m3fn]:
            # MMA K extent depends only on the element type.
            comptime inst_k = 32 if elem_type == DType.float8_e4m3fn else 16

            @parameter
            for swz in [TensorMapSwizzle.SWIZZLE_128B]:
                # Number of elements covered by `swz.bytes()` bytes.
                comptime swz_elems = swz.bytes() // size_of[elem_type]()

                @parameter
                for k_mult in range(1, 3):
                    # Block K is one or two swizzle widths.
                    comptime tile_k = swz_elems * k_mult

                    @parameter
                    for m_mult in range(1, 3):
                        # MMA M is 64 or 128.
                        comptime inst_m = 64 * m_mult
                        comptime tile_shape = Index(inst_m, 128, tile_k)
                        comptime inst_shape = Index(inst_m, 128, inst_k)

                        @parameter
                        for scale in range(1, 3):
                            # Problem is `scale` tiles along every dimension.
                            comptime problem = Index(
                                inst_m * scale,
                                128 * scale,
                                tile_k * scale,
                            )

                            @parameter
                            for tb in range(0, 2):
                                comptime b_transposed = Bool(tb)

                                # A and B both in shared memory, K-major A.
                                test_tma_umma[
                                    elem_type,
                                    elem_type,
                                    DType.bfloat16,
                                    problem,
                                    tile_shape,
                                    inst_shape,
                                    a_swizzle=swz,
                                    b_swizzle=swz,
                                    transpose_b=b_transposed,
                                ](ctx)

                                # The remaining variants run for bfloat16
                                # only.
                                @parameter
                                if elem_type == DType.bfloat16:
                                    # A operand sourced from tensor memory.
                                    test_tma_umma[
                                        elem_type,
                                        elem_type,
                                        DType.bfloat16,
                                        problem,
                                        tile_shape,
                                        inst_shape,
                                        b_swizzle=swz,
                                        transpose_b=b_transposed,
                                        a_smem=False,
                                    ](ctx)

                                    # Transposed A, both operands swizzled.
                                    test_tma_umma[
                                        elem_type,
                                        elem_type,
                                        DType.bfloat16,
                                        problem,
                                        tile_shape,
                                        inst_shape,
                                        a_swizzle=swz,
                                        b_swizzle=swz,
                                        transpose_a=True,
                                        transpose_b=b_transposed,
                                    ](ctx)

                                    # Transposed A, only A swizzled.
                                    test_tma_umma[
                                        elem_type,
                                        elem_type,
                                        DType.bfloat16,
                                        problem,
                                        tile_shape,
                                        inst_shape,
                                        a_swizzle=swz,
                                        b_swizzle = TensorMapSwizzle.SWIZZLE_NONE,
                                        transpose_a=True,
                                        transpose_b=b_transposed,
                                    ](ctx)

                                    # Transposed A, no swizzle on either
                                    # operand.
                                    test_tma_umma[
                                        elem_type,
                                        elem_type,
                                        DType.bfloat16,
                                        problem,
                                        tile_shape,
                                        inst_shape,
                                        a_swizzle = TensorMapSwizzle.SWIZZLE_NONE,
                                        b_swizzle = TensorMapSwizzle.SWIZZLE_NONE,
                                        transpose_a=True,
                                        transpose_b=b_transposed,
                                    ](ctx)

        # Sweep 2: a single small (64 or 128) x 8 x 16 bfloat16 tile without
        # swizzle, for every transpose combination. Problem, block tile and
        # MMA shape are all the same.
        @parameter
        for scale in range(1, 3):
            comptime small_shape = Index(scale * 64, 8, 16)

            @parameter
            for ta in range(0, 2):

                @parameter
                for tb in range(0, 2):
                    test_tma_umma[
                        DType.bfloat16,
                        DType.bfloat16,
                        DType.bfloat16,
                        small_shape,
                        small_shape,
                        small_shape,
                        a_swizzle = TensorMapSwizzle.SWIZZLE_NONE,
                        b_swizzle = TensorMapSwizzle.SWIZZLE_NONE,
                        transpose_a = Bool(ta),
                        transpose_b = Bool(tb),
                    ](ctx)
